// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_OUTPUTADDRESSGENERATOR_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_OUTPUTADDRESSGENERATOR_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

// Pairs every output vector produced by one of the vector units with the
// address it should be written to.
//
// For each VectorParams popped from `paramsIn`, the module selects one data
// source (softmax unit if params.SOFTMAX, else layer-norm unit if
// params.LAYER_NORM, else the scale unit), forwards each popped vector to
// `outputData` and pushes the matching address (base address plus
// params.OUTPUT_OFFSET) to `outputAddress`. Once every vector of the layer has
// been forwarded, `done` is signalled.
//
// Template parameters:
//   DTYPE - element type of the packed vectors.
//   NROWS - multiplier applied to the x / k indices when forming addresses.
//   WIDTH - number of elements in each Pack1D vector.
template <typename DTYPE, int NROWS, int WIDTH>
SC_MODULE(OutputAddressGenerator) {
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);  // Asynchronous reset, active low.

  // Per-layer parameters; one entry is consumed per layer.
  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  // Data sources; exactly one of them is drained for a given layer.
  Connections::In<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(scaleUnitOutput);
  Connections::In<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(softmaxUnitOutput);
  Connections::In<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(layerNormUnitOutput);

  // One address / data pair is pushed for every vector taken from a source.
  Connections::Out<int> CCS_INIT_S1(outputAddress);
  Connections::Out<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(outputData);

  // Pushed once per layer, after the last address / data pair.
  Connections::SyncOut CCS_INIT_S1(done);

  SC_CTOR(OutputAddressGenerator) {
    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Main thread: resets all ports, then processes one layer per iteration.
  void run() {
    paramsIn.Reset();
    scaleUnitOutput.Reset();
    softmaxUnitOutput.Reset();
    layerNormUnitOutput.Reset();
    outputAddress.Reset();
    outputData.Reset();
    done.Reset();

    wait();

    while (true) {
      VectorParams params = paramsIn.Pop();

      if (params.SOFTMAX) {
        // Only loop level 1 is walked here; the bounds do not change while the
        // layer is processed, so read them once.
        const int Y = params.loops[1][params.yLoopIndex[1]];
        const int X = params.loops[1][params.xLoopIndex[1]];
#pragma hls_pipeline_init_interval 1
        for (int y = 0; y < Y; y++) {
          for (int x = 0; x < X; x++) {
            // Row-major (y, x) address, scaled by NROWS.
            int baseAddress = y * X * NROWS + x * NROWS;

            outputData.Push(softmaxUnitOutput.Pop());
            outputAddress.Push(baseAddress + params.OUTPUT_OFFSET);
          }
        }
      } else if (params.LAYER_NORM) {
        // Same traversal as the softmax case, fed by the layer-norm unit.
        const int Y = params.loops[1][params.yLoopIndex[1]];
        const int X = params.loops[1][params.xLoopIndex[1]];
#pragma hls_pipeline_init_interval 1
        for (int y = 0; y < Y; y++) {
          for (int x = 0; x < X; x++) {
            int baseAddress = y * X * NROWS + x * NROWS;

            outputData.Push(layerNormUnitOutput.Pop());
            outputAddress.Push(baseAddress + params.OUTPUT_OFFSET);
          }
        }
      } else {
        // General case: a two-level, three-deep loop nest whose ordering is
        // described by params.{x,y,k}LoopIndex. Level 0 is the outer level,
        // level 1 the inner one.
        int loop_counters[2][3];
        int loop_bounds[2][3];

#pragma hls_unroll yes
        for (int i = 0; i < 2; i++) {
          for (int j = 0; j < 3; j++) {
            loop_bounds[i][j] = params.loops[i][j];
          }
        }

        // Bounds and derived extents are invariant over the whole nest.
        const int X0 = loop_bounds[1][params.xLoopIndex[1]];
        const int X1 = loop_bounds[0][params.xLoopIndex[0]];
        const int X = X1 * X0;

        const int Y0 = loop_bounds[1][params.yLoopIndex[1]];

        const int K0 = loop_bounds[1][params.kLoopIndex[1]];
        const int K1 = loop_bounds[0][params.kLoopIndex[0]];
        const int K = K1 * K0 * NROWS;

#pragma hls_pipeline_init_interval 1
        for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
             loop_counters[0][0]++) {
          for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
               loop_counters[0][1]++) {
            for (loop_counters[0][2] = 0;
                 loop_counters[0][2] < loop_bounds[0][2];
                 loop_counters[0][2]++) {
              for (loop_counters[1][0] = 0;
                   loop_counters[1][0] < loop_bounds[1][0];
                   loop_counters[1][0]++) {
                for (loop_counters[1][1] = 0;
                     loop_counters[1][1] < loop_bounds[1][1];
                     loop_counters[1][1]++) {
                  for (loop_counters[1][2] = 0;
                       loop_counters[1][2] < loop_bounds[1][2];
                       loop_counters[1][2]++) {
                    // Combine the outer (level 0) and inner (level 1)
                    // counters into flat x / y / k coordinates.
                    int x0 = loop_counters[1][params.xLoopIndex[1]];
                    int x1 = loop_counters[0][params.xLoopIndex[0]];
                    int x = x1 * X0 + x0;

                    int y0 = loop_counters[1][params.yLoopIndex[1]];
                    int y1 = loop_counters[0][params.yLoopIndex[0]];
                    int y = y1 * Y0 + y0;

                    int k0 = loop_counters[1][params.kLoopIndex[1]];
                    int k1 = loop_counters[0][params.kLoopIndex[0]];
                    int k = k1 * K0 * NROWS + k0 * NROWS;

                    // Default layout: (y, x, k) with k varying fastest.
                    int baseAddress = y * X * K + x * K + k;

                    if (params.SPLIT_HEAD) {
                      // Split-head layout: (head, x, offset-within-head).
                      // The head stride must be the full row length X (not the
                      // running index x) and x needs its own HEAD_SIZE stride;
                      // otherwise different (x, k) pairs collide on the same
                      // address (e.g. every head maps to `k % HEAD_SIZE` when
                      // x == 0).
                      // NOTE(review): y does not take part in the split-head
                      // address (it did not before this fix either);
                      // presumably split-head layers have a single y row -
                      // confirm against the callers.
                      int HEAD_SIZE = (1 << params.HEAD_SZ_LG2);
                      int head = k / HEAD_SIZE;
                      baseAddress =
                          (head * X + x) * HEAD_SIZE + (k % HEAD_SIZE);
                    }

                    outputData.Push(scaleUnitOutput.Pop());
                    outputAddress.Push(baseAddress + params.OUTPUT_OFFSET);
                  }
                }
              }
            }
          }
        }
      }
      // All vectors of this layer have been forwarded.
      done.SyncPush();
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_OUTPUTADDRESSGENERATOR_H_
